import networkx as nx
import argparse
import time
import os


class IpGen:
    """Hand out 32-bit integer addresses, one network at a time.

    The generator starts at the highest network number that fits in
    ``prefix`` bits and counts downwards; each call to :meth:`gen` returns
    the current network number shifted into the top ``prefix`` bits of a
    32-bit value.
    """

    def __init__(self, prefix=24):
        # Number of leading bits that identify the network.
        self.prefix = prefix
        # Current network number; starts with all ``prefix`` bits set.
        self.network = (1 << prefix) - 1
        # Counters initialised to 1; gen() does not read them.
        self.m, self.n, self.o = 1, 1, 1

    def gen(self):
        """Return the current network address and step to the next lower one."""
        host_bits = 32 - self.prefix
        address = self.network << host_bits
        self.network -= 1
        return address


def gen_topo(npod):
    """Build the switch graph for ``npod`` pods.

    Every pod contributes 4 ``fsw-<pod>-<i>`` nodes.  Each of them is linked
    to the 48 ``rsw-<pod>-<j>`` nodes of its own pod and to the 48
    ``ssw-<i>-<k>`` nodes whose first index equals the fsw index ``i`` (so
    the ssw nodes are shared by all pods).

    The node and edge counts are printed to stdout before returning.

    Returns:
        The populated ``networkx.Graph``.
    """
    graph = nx.Graph()
    for pod in range(npod):
        for plane in range(4):
            fsw = 'fsw-%d-%d' % (pod, plane)
            # Links down to the rsw nodes first, then up to the ssw nodes;
            # insertion order determines the graph's edge iteration order.
            graph.add_edges_from(
                (fsw, 'rsw-%d-%d' % (pod, rsw)) for rsw in range(48))
            graph.add_edges_from(
                (fsw, 'ssw-%d-%d' % (plane, ssw)) for ssw in range(48))

    print(len(graph.nodes), len(graph.edges))
    return graph


def dump_topo(topo, output):
    """Write every edge of ``topo`` to ``<output>/topo.txt``.

    Each edge becomes one line of the form ``a b b a``: the two endpoint
    names followed by the same two names in reverse order.

    Args:
        topo: object whose ``edges`` attribute iterates over edges that can
            be indexed with ``[0]`` and ``[1]``.
        output: directory to write into; it is created if missing.  An empty
            string means the current working directory (plain string
            concatenation would otherwise resolve to ``/topo.txt``).
    """
    if output:
        # os.makedirs('') raises, so only create a directory when one is named.
        os.makedirs(output, exist_ok=True)
    path = os.path.join(output, 'topo.txt')
    with open(path, 'w') as f:
        f.writelines(
            '%s %s %s %s\n' % (edge[0], edge[1], edge[1], edge[0])
            for edge in topo.edges)


def gen_fib(topo, npod):
    """Build the forwarding rules of every node for ``npod`` pods.

    One destination is generated per rsw node: ``rsw-<pod>-<idx>`` owns the
    address ``(pod << 24) + (idx << 16)`` with prefix length 16.  For each
    destination the following rules are appended, in destination order:

    * every other rsw node forwards to the 4 fsw nodes of its own pod;
    * an fsw node in the destination's pod forwards straight to the
      destination rsw, an fsw node ``fsw-<pod>-<i>`` in any other pod
      forwards to the 48 ``ssw-<i>-*`` nodes;
    * ssw node ``ssw-<i>-*`` forwards to ``fsw-<destination pod>-<i>``.

    Args:
        topo: object whose ``nodes`` attribute iterates over all node names;
            it must contain every rsw/fsw/ssw name used above.
        npod: number of pods.

    Returns:
        dict mapping node name to a list of ``[ip, 16, [out_port, ...]]``
        rules.  Every rule owns its own port list.

    Raises:
        KeyError: if a generated node name is missing from ``topo.nodes``.
    """
    fibs = {node: [] for node in topo.nodes}

    # Node names are formatted once here instead of once per
    # (source, destination) pair inside the innermost loops.
    rsw_names = [['rsw-%d-%d' % (pod, i) for i in range(48)]
                 for pod in range(npod)]
    fsw_names = [['fsw-%d-%d' % (pod, i) for i in range(4)]
                 for pod in range(npod)]
    ssw_names = [['ssw-%d-%d' % (plane, i) for i in range(48)]
                 for plane in range(4)]

    for dst_pod in range(npod):
        for dst_idx, dst_rsw in enumerate(rsw_names[dst_pod]):
            ip = (dst_pod << 24) + (dst_idx << 16)

            for src_pod in range(npod):
                uplinks = fsw_names[src_pod]
                for src_rsw in rsw_names[src_pod]:
                    if src_rsw != dst_rsw:
                        # list(...) keeps one independent port list per rule.
                        fibs[src_rsw].append([ip, 16, list(uplinks)])

                for plane, src_fsw in enumerate(uplinks):
                    if src_pod == dst_pod:
                        ports = [dst_rsw]
                    else:
                        ports = list(ssw_names[plane])
                    fibs[src_fsw].append([ip, 16, ports])

            for plane in range(4):
                down_fsw = fsw_names[dst_pod][plane]
                for src_ssw in ssw_names[plane]:
                    fibs[src_ssw].append([ip, 16, [down_fsw]])
    return fibs

def dump_FIB(FIBs, output):
    """Write one rule file per node into ``<output>/rules/``.

    The file is named after the node and holds one line per rule in the
    form ``fw <ip> <prefix> <port1>,<port2>,...``.

    Args:
        FIBs: dict mapping node name to a list of rules, each rule being
            indexable as ``[ip, prefix, ports]`` with integer ``ip`` and
            ``prefix`` and an iterable of port-name strings.  Node names are
            used verbatim as file names inside the rules directory.
        output: base directory; ``rules`` is created beneath it if missing.
            An empty string means the current working directory.
    """
    rules_dir = os.path.join(output, 'rules')
    # All files share one directory, so create it once rather than per node.
    os.makedirs(rules_dir, exist_ok=True)
    for node, rules in FIBs.items():
        with open(os.path.join(rules_dir, node), 'w') as f:
            f.writelines(
                'fw %d %d %s\n' % (rule[0], rule[1], ','.join(rule[2]))
                for rule in rules)


if __name__ == '__main__':
    # Script entry point: always generates and writes the topology; with -f
    # it also builds the FIBs (printing a timestamp before and after the
    # build) and writes them out.
    cli = argparse.ArgumentParser(
        description="The FIB format is: fw ip prefix outports, read as \"a node has a rule ip/prefix that forward to outports\"")

    # (flags, keyword arguments) for each option, registered in this order.
    # -nprefix and -prefix are parsed but not read anywhere in this block.
    option_specs = (
        (("output",), {"help": "the output file"}),
        (("-f",), {"action": "store_true", "help": 'generate FIB'}),
        (("-k",), {"type": int, "default": 112, "help": "number of pods"}),
        (("-nprefix",), {
            "type": int, "default": 1,
            "help": "the number of prefixes on each node, default=1"}),
        (("-prefix",), {
            "type": int, "default": 24,
            "help": "the prefix for each address, default=24"}),
    )
    for flags, kwargs in option_specs:
        cli.add_argument(*flags, **kwargs)
    opts = cli.parse_args()

    graph = gen_topo(opts.k)
    dump_topo(graph, opts.output)
    if opts.f:
        print(time.time())
        rules = gen_fib(graph, opts.k)
        print(time.time())
        dump_FIB(rules, opts.output)
